{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "39c9708b",
   "metadata": {},
   "source": [
    "# Grad-CAM\n",
    "\n",
    "### 参数说明\n",
    "\n",
    "1. sample_dir： 你自己的样本目录\n",
    "2. model_root：你自己的模型目录，注意这里不需要精确到viz目录，所有的Grad-CAM图会生成的此目录的Grad-CAM文件夹\n",
    "3. target_layer：你自己喜欢的层的名称，如果不知道具体的参数名字，可以先运行一次，在输入修改。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a412b649",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from onekey_algo import get_param_in_cwd\n",
    "from onekey_algo.custom.components.comp2 import target_layer_mapping\n",
    "\n",
    "sample_dir = get_param_in_cwd('data_pattern')\n",
    "model_root = os.path.join(get_param_in_cwd('model_root'), get_param_in_cwd('model_name'))\n",
    "target_layer = target_layer_mapping[get_param_in_cwd('model_name') + '_2D']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9eb2b98",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from onekey_algo.datasets.image_loader import default_loader\n",
    "from onekey_algo.custom.components.comp2 import show_cam_on_image\n",
    "import torch\n",
    "import os\n",
    "import random\n",
    "from onekey_algo import get_param_in_cwd\n",
    "from onekey_algo.custom.components.comp2 import extract, init_from_model, init_from_onekey\n",
    "from onekey_algo.utils.MultiProcess import MultiProcess\n",
    "import numpy as np\n",
    "os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'\n",
    "\n",
    "import monai\n",
    "from glob import glob\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import matplotlib\n",
    "matplotlib.use('Agg')\n",
    "from onekey_algo import get_param_in_cwd\n",
    "from onekey_algo.custom.components.comp2 import target_layer_mapping\n",
    "from onekey_algo.datasets.image_loader import default_loader\n",
    "from onekey_algo.custom.components.comp2 import show_cam_on_image\n",
    "from onekey_algo.custom.components.comp2 import extract, init_from_model, init_from_onekey\n",
    "from onekey_algo.utils.MultiProcess import MultiProcess\n",
    "\n",
    "\n",
    "def viz_sample(samples, thread_id):\n",
    "    model, transformer, device = init_from_onekey(os.path.join(model_root, 'viz'))\n",
    "#     for n, m in model.named_modules():\n",
    "#         print('Feature name:', n, \"|| Module:\", m)\n",
    "    gradcam = monai.visualize.GradCAM(nn_module=model, target_layers=target_layer)\n",
    "    viz_dir = os.path.join(model_root, 'Grad-CAM')\n",
    "    os.makedirs(viz_dir, exist_ok=True)\n",
    "    for sample in samples:\n",
    "        if os.path.exists(os.path.join(viz_dir, os.path.basename(sample))):\n",
    "            continue\n",
    "        img = default_loader(sample)\n",
    "        sample_ = transformer(img)\n",
    "        sample_  = sample_.view(1, *sample_.size()).to(device)\n",
    "        res_cam = gradcam(x=sample_, class_idx=None)\n",
    "        fig, axes = plt.subplots(1, 3, figsize=(12, 4), facecolor='white')\n",
    "    #     axes[0].imshow(-res_cam[0][0].cpu(), cmap='jet')\n",
    "        axes[0].imshow(img.resize(sample_.size()[2:]))\n",
    "        axes[0].axis('off')\n",
    "    #     plt.savefig(f\"viz/{os.path.basename(sample).replace('.png', '_se.png')}\", bbox_inches = 'tight')\n",
    "    #     plt.show()\n",
    "    #     plt.figure(figsize=(10, 10))\n",
    "    #     plt.axis('off')\n",
    "        imshow = axes[1].imshow(-res_cam[0][0].cpu(),cmap='jet')\n",
    "        axes[1].axis('off')\n",
    "        imshow = axes[2].imshow(show_cam_on_image(img.resize(sample_.size()[2:]), -res_cam[0][0].cpu(), use_rgb=True, reverse=False), \n",
    "                                cmap='jet')\n",
    "        axes[2].axis('off')\n",
    "        cax = fig.add_axes([0.92, 0.17, 0.02, axes[2].get_position().height]) \n",
    "        plt.colorbar(imshow, cax=cax)\n",
    "        plt.savefig(f'{viz_dir}/{os.path.basename(sample).replace(\".npy\", \".png\")}', bbox_inches = 'tight')\n",
    "        plt.show()\n",
    "        plt.close(fig)\n",
    "\n",
    "for model_name in get_param_in_cwd('model_names'):\n",
    "    sample_dir = get_param_in_cwd('data_pattern2d')\n",
    "    model_root = os.path.join(get_param_in_cwd('model_root'), '2D', model_name)\n",
    "    target_layer = target_layer_mapping[model_name + '_2D']\n",
    "    samples = glob(os.path.join(sample_dir, '*'))\n",
    "    random.shuffle(samples)\n",
    "    viz_sample(samples, thread_id=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f86e48e2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
